/* vim: set ft=cpp:
 *
 * Copyright (c) 2018- Apple Computer, Inc. All rights reserved.
 *
 * @APPLE_LICENSE_HEADER_START@
 *
 * The contents of this file constitute Original Code as defined in and
 * are subject to the Apple Public Source License Version 2.0 (the
 * "License").  You may not use this file except in compliance with the
 * License.  Please obtain a copy of the License at
 * http://www.apple.com/publicsource and read it before using this file.
 *
 * This Original Code and all software distributed under the License are
 * distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, EITHER
 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT.  Please see the
 * License for the specific language governing rights and limitations
 * under the License.
 *
 * @APPLE_LICENSE_HEADER_END@
 */

// IOKit deliberately does not support the STL, so this header cherry-picks
// some of the more useful templates from C++11 and later, in particular the
// smart pointers. This file provides OSObject specializations of
// std::unique_ptr, std::shared_ptr and std::weak_ptr.

#ifndef IOG_TL_OSMEMORY
#define IOG_TL_OSMEMORY

#if KERNEL
#include <kern/assert.h>
#include <kern/debug.h>
#else
#include <cassert>
#endif

#include <_config>
#include <osutility>

#include <os/refcnt.h>

#include <libkern/c++/OSObject.h>

_IOG_START_NAMESPACE

template<class _T> class OSUniqueObject;
template<class _T> class OSSharedObject;
template<class _T> class OSWeakObject;

// std::unique_ptr<T> equivalent.
//
// Holds at most one pointer and hands it to OSSafeReleaseNULL() when the
// owner is destroyed, reassigned or reset. Ownership can be moved (including
// from an OSUniqueObject<_Y> whose pointer converts to _T*) but never copied.
template<class _T>
class OSUniqueObject
{
public:
    using pointer = _T*;
    using element_type = _T;

private:
    pointer fTypedObj = nullptr;

    // The converting move constructor/assignment read and clear fTypedObj of
    // a *different* instantiation, so all OSUniqueObject<_Y> must be friends;
    // without this those templates cannot be instantiated for _Y != _T.
    template <class _Y> friend class OSUniqueObject;
    template <class _Y> friend class OSSharedObject;

    // Release the held object (if any) and leave this owner empty.
    void impl_release()
    {
        OSSafeReleaseNULL(fTypedObj);
    }

    // Shared implementation of move assignment from any OSUniqueObject<_Y>.
    //
    // The source is emptied *before* our current object is released. That
    // makes self-move a no-op (we take our own pointer, release nothing and
    // put it back) without comparing `this` against `&other`, a comparison
    // that is ill-formed when _Y differs from _T because the two pointer
    // types are unrelated.
    template <class _Y>
    OSUniqueObject& impl_opassign(OSUniqueObject<_Y>&& other)
    {
        pointer incoming = other.fTypedObj;
        other.fTypedObj = nullptr;
        impl_release();
        fTypedObj = incoming;
        return *this;
    }

public:
    OSUniqueObject() = default;
    ~OSUniqueObject()
    {
        impl_release();
    }

    // Just like std::unique_ptr, takes ownership of p
    OSUniqueObject(pointer p) : fTypedObj(p) { }

    // Construct / assign from nullptr: the result owns nothing.
    explicit OSUniqueObject(nullptr_t /* np */) {}
    OSUniqueObject& operator=(nullptr_t /* np */)
    {
        impl_release();
        return *this;
    }

    // Move: the source is left empty and this owner takes over its pointer.
    OSUniqueObject(OSUniqueObject&& other) : fTypedObj(other.fTypedObj)
        { other.fTypedObj = nullptr; }
    template <class _Y>
    OSUniqueObject(OSUniqueObject<_Y>&& other) : fTypedObj(other.fTypedObj)
        { other.fTypedObj = nullptr; }
    OSUniqueObject& operator=(OSUniqueObject&& other)
        { return impl_opassign(_VIOG::move(other)); }
    template <class _Y>
    OSUniqueObject& operator=(OSUniqueObject<_Y>&& other)
        { return impl_opassign(_VIOG::move(other)); }

    // Copy is illegal
    OSUniqueObject(const OSUniqueObject&) = delete;
    template <class _Y>
    OSUniqueObject(const OSUniqueObject<_Y>& other) = delete;
    OSUniqueObject& operator=(const OSUniqueObject& other) = delete;
    template <class _Y>
    OSUniqueObject& operator=(const OSUniqueObject<_Y>& other) = delete;

    // Member functions
    pointer get() const             { return  fTypedObj; }
    pointer operator->() const      { return  fTypedObj; }
    _T& operator*() const           { return *fTypedObj; }
    // True when a non-null pointer is held.
    operator bool() const           { return static_cast<bool>(get()); }
    // Release the current object (if any) and take ownership of p instead.
    void reset(pointer p = nullptr) { *this = OSUniqueObject(p); }

    // Exchange the held pointers; no retain or release is performed.
    void swap(OSUniqueObject& other)
    {
        _VIOG::swap(fTypedObj, other.fTypedObj);
    }
}; // OSUniqueObject<_T>

// OSUniqueObject non-member functions

// Free-function swap; forwards to OSUniqueObject::swap().
template<class _T>
void swap(OSUniqueObject<_T>& lhs, OSUniqueObject<_T>& rhs)
{
    lhs.swap(rhs);
}

// Equality between two OSUniqueObjects compares the raw pointers returned by
// get(), which is also what std::unique_ptr does.
// NOTE(review): the previous comment here objected to that semantic, arguing
// that two distinct unique owners holding the same pointer already violates
// the uniqueness assumption. The code nevertheless compares get(); confirm
// which behaviour is intended.
template <class _T, class _Y>
bool operator==(const OSUniqueObject<_T>& lhs, const OSUniqueObject<_Y>& rhs)
{
    return lhs.get() == rhs.get();
}

// Inequality is defined as the negation of operator== above.
template<class _T, class _Y>
bool operator!=(const OSUniqueObject<_T>& lhs, const OSUniqueObject<_Y>& rhs)
{
    const bool same = (lhs == rhs);
    return !same;
}

// Comparisons against nullptr, in both operand orders. An OSUniqueObject
// equals nullptr exactly when its bool conversion yields false.
template<class _T>
bool operator==(const OSUniqueObject<_T>& lhs, _VIOG::nullptr_t)
{
    const bool holdsObject = static_cast<bool>(lhs);
    return !holdsObject;
}

template<class _T>
bool operator==(_VIOG::nullptr_t, const OSUniqueObject<_T>& rhs)
{
    const bool holdsObject = static_cast<bool>(rhs);
    return !holdsObject;
}

template<class _T>
bool operator!=(const OSUniqueObject<_T>& lhs, _VIOG::nullptr_t)
{
    const bool holdsObject = static_cast<bool>(lhs);
    return holdsObject;
}

template<class _T>
bool operator!=(_VIOG::nullptr_t, const OSUniqueObject<_T>& rhs)
{
    const bool holdsObject = static_cast<bool>(rhs);
    return holdsObject;
}


// Helper (control block) used by OSSharedObject and OSWeakObject.
//
// Everything is private; only the two smart pointer templates, declared as
// friends, may create or manipulate a counter. The member functions other
// than lock() are defined out of line and are not visible in this header.
class _OSSharedWeakCounter
{
    template <class _Y> friend class OSSharedObject;
    template <class _Y> friend class OSWeakObject;

    OSObject* fObj;          // object passed to the constructor
    os_refcnt fUseCount;     // presumably counts OSSharedObject owners
    os_refcnt fWeakCount;    // presumably counts OSWeakObject observers

    explicit _OSSharedWeakCounter(OSObject* obj);
    ~_OSSharedWeakCounter();

    bool retain_shared();
    void release_shared();
    long count_shared() const;
    void retain_weak();
    void release_weak();

    // Try to take a shared reference. Returns this counter when
    // retain_shared() reports success and nullptr otherwise.
    _OSSharedWeakCounter* lock()
    {
        if (retain_shared())
            return this;
        return nullptr;
    }
}; // class _OSSharedWeakCounter

// std::shared_ptr<T> equivalent
template<class _T>
class OSSharedObject
{
public:
    using element_type = _T;
    using pointer = _T*;
    using weak_type = OSWeakObject<_T>;

private:
    _OSSharedWeakCounter* fControl = nullptr;
    pointer fTypedObj = nullptr;

    template <class _Y> friend class OSSharedObject;
    template <class _Y> friend class OSWeakObject;

    void impl_zero()
    {
        fControl = nullptr;
        fTypedObj = nullptr;
    }
    void impl_setme(const OSSharedObject& from)
    {
        fControl  = from.fControl;
        fTypedObj = from.fTypedObj;
    }
    template <class _Y>
    void impl_setme(const OSSharedObject<_Y>& from)
    {
        fControl  = from.fControl;
        fTypedObj = from.fTypedObj;
    }
    void impl_saferetain()
    {
        if (static_cast<bool>(fControl))
            fControl->retain_shared();
    }
    void impl_saferelease()
    {
        if (static_cast<bool>(fControl))
            fControl->release_shared();
    }

public:
    // Basic constructor destructors

    // Just like std::shared_ptr, takes ownership of p
    OSSharedObject(_T* tp)
        : fControl(new _OSSharedWeakCounter(tp)), fTypedObj(tp) {}
    template<class _Y>
    OSSharedObject(_Y* yp)
    : fControl(new _OSSharedWeakCounter(yp)), fTypedObj(yp) {}

    OSSharedObject() = default;
    ~OSSharedObject()
    {
        impl_saferelease();
        impl_zero();
    }

    // Move
    OSSharedObject(OSSharedObject&& other)
    {
        impl_setme(other);
        other.impl_zero();
    }
    template<class _Y>
    OSSharedObject(OSSharedObject<_Y>&& other)
    {
        impl_setme(other);
        other.impl_zero();
    }
    OSSharedObject& operator=(OSSharedObject&& other)
    {
        if (this != &other)
        {
            impl_saferelease();
            impl_setme(other);
            other.impl_zero();
        }
        return *this;
    }
    template<class _Y>
    OSSharedObject& operator=(OSSharedObject<_Y>&& other)
    {
        if (this != &other)
        {
            impl_saferelease();
            impl_setme(other);
            other.impl_zero();
        }
        return *this;
    }

    // Copy
    OSSharedObject(const OSSharedObject& other)
    {
        impl_setme(other);
        impl_saferetain();
    }
    template<class _Y>
    OSSharedObject(const OSSharedObject<_Y>& other)
    {
        impl_setme(other);
        impl_saferetain();
    }
    template<class _Y>
    OSSharedObject(const OSSharedObject<_Y>& other, element_type* p)
        : fControl(other.fControl)  // Aliasing copy
    {
        if (static_cast<bool>(fControl)) {
            fControl->retain_shared();
            fTypedObj = p;
        }
    }
    OSSharedObject& operator=(const OSSharedObject& other)
    {
        if (this != &other)
        {
            impl_saferelease();
            impl_setme(other);
            impl_saferetain();
        }
        return *this;
    }

    template<class _Y>
    OSSharedObject& operator=(const OSSharedObject<_Y>& other)
    {
        if (this != &other)
        {
            impl_saferelease();
            impl_setme(other);
            impl_saferetain();
        }
        return *this;
    }


    pointer get() const        { return  fTypedObj; }
    pointer operator->() const { return  fTypedObj; }
    _T& operator*() const      { return *fTypedObj; }
    operator bool() const      { return static_cast<bool>(get()); }
    void reset()               { *this = OSSharedObject(); };
    long use_count() const
    {
        return static_cast<bool>(fControl)? fControl->count_shared() : 0;
    }

    void swap(OSSharedObject& other)
    {
        _VIOG::swap(fControl,  other.fControl);
        _VIOG::swap(fTypedObj, other.fTypedObj);
    }
};

// OSSharedObject non-member functions

// Free-function swap; forwards to OSSharedObject::swap().
template<class _T>
void swap(OSSharedObject<_T>& lhs, OSSharedObject<_T>& rhs)
{
    lhs.swap(rhs);
}

// static_cast the pointer held by `other` to _T* and return an aliasing
// OSSharedObject<_T> that shares the control block of `other`.
template<class _T, class _Y>
OSSharedObject<_T> static_pointer_cast(const OSSharedObject<_Y>& other)
{
    using target_type = typename OSSharedObject<_T>::element_type;
    target_type* converted = static_cast<target_type*>(other.get());
    return OSSharedObject<_T>(other, converted);
}

// Run OSDynamicCast(_T, ...) on the pointer held by `other`. When the cast
// yields a non-null pointer the result aliases `other` (sharing its control
// block); otherwise an empty OSSharedObject<_T> is returned.
template<class _T, class _Y>
OSSharedObject<_T> dynamic_pointer_cast(const OSSharedObject<_Y>& other)
{
    auto converted = OSDynamicCast(_T, other.get());
    if (!static_cast<bool>(converted))
        return OSSharedObject<_T>();

    return OSSharedObject<_T>(other, converted);
}


// const_cast the pointer held by `other` to _T* and return an aliasing
// OSSharedObject<_T> that shares the control block of `other`.
template<class _T, class _Y>
OSSharedObject<_T> const_pointer_cast(const OSSharedObject<_Y>& other)
{
    using target_type = typename OSSharedObject<_T>::element_type;
    target_type* converted = const_cast<target_type*>(other.get());
    return OSSharedObject<_T>(other, converted);
}

// reinterpret_cast the pointer held by `other` to _T* and return an aliasing
// OSSharedObject<_T> that shares the control block of `other`.
template<class _T, class _Y>
OSSharedObject<_T> reinterpret_pointer_cast(const OSSharedObject<_Y>& other)
{
    using target_type = typename OSSharedObject<_T>::element_type;
    target_type* converted = reinterpret_cast<target_type*>(other.get());
    return OSSharedObject<_T>(other, converted);
}

// Two OSSharedObjects are equal when get() returns the same pointer for
// both; the control blocks are not compared.
template <class _T, class _Y>
bool operator==(const OSSharedObject<_T>& lhs, const OSSharedObject<_Y>& rhs)
{
    return lhs.get() == rhs.get();
}

// Inequality is the negation of operator== above.
template<class _T, class _Y>
bool operator!=(const OSSharedObject<_T>& lhs, const OSSharedObject<_Y>& rhs)
{
    const bool same = (lhs == rhs);
    return !same;
}

// Comparisons against nullptr, in both operand orders. An OSSharedObject
// equals nullptr exactly when its bool conversion yields false.
template<class _T>
bool operator==(const OSSharedObject<_T>& lhs, _VIOG::nullptr_t)
{
    const bool holdsObject = static_cast<bool>(lhs);
    return !holdsObject;
}

template<class _T>
bool operator==(_VIOG::nullptr_t, const OSSharedObject<_T>& rhs)
{
    const bool holdsObject = static_cast<bool>(rhs);
    return !holdsObject;
}

template<class _T>
bool operator!=(const OSSharedObject<_T>& lhs, _VIOG::nullptr_t)
{
    const bool holdsObject = static_cast<bool>(lhs);
    return holdsObject;
}

template<class _T>
bool operator!=(_VIOG::nullptr_t, const OSSharedObject<_T>& rhs)
{
    const bool holdsObject = static_cast<bool>(rhs);
    return holdsObject;
}

// std::weak_ptr<T> equivalent
template<class _T>
class OSWeakObject
{
public:
    using element_type = _T;

private:
    _OSSharedWeakCounter* fControl = nullptr;
    _T* fTypedObj = nullptr;

    template <class _Y> friend class OSSharedObject;
    template <class _Y> friend class OSWeakObject;

    void impl_zero()
    {
        fControl = nullptr;
        fTypedObj = nullptr;
    }
    void impl_saferetain()
    {
        if (static_cast<bool>(fControl))
            fControl->retain_weak();
    }
    void impl_saferelease()
    {
        if (static_cast<bool>(fControl))
            fControl->release_weak();
    }

public:
    OSWeakObject() = default;
    ~OSWeakObject()
    {
        impl_saferelease();
        impl_zero();
    }

    // Conversion from shared to weak
    template<class _Y>
    OSWeakObject(const OSSharedObject<_Y>& shared)
        : fControl(shared.fControl), fTypedObj(shared.fTypedObj)
    {
        impl_saferetain();
    }
    template<class _Y>
    OSWeakObject& operator=(const OSSharedObject<_Y>& shared)
    {
        if (this != &shared)
        {
            impl_saferelease();
            fControl = shared.fControl;
            fTypedObj = shared.fTypedObj;
            impl_saferetain();
        }
        return *this;
    }

    // Copy
    OSWeakObject(const OSWeakObject& other)
        : fControl(other.fControl), fTypedObj(other.fTypedObj)
    {
        impl_saferetain();
    }
    template<class _Y>
    OSWeakObject(const OSWeakObject<_Y>& other)
        : fControl(other.fControl), fTypedObj(other.fTypedObj)
    {
        impl_saferetain();
    }
    OSWeakObject& operator=(const OSWeakObject& other)
    {
        if (this != &other)
        {
            impl_saferelease();
            fControl = other.fControl;
            fTypedObj = other.fTypedObj;
            impl_saferetain();
        }
        return *this;
    }

    // Move
    OSWeakObject(OSWeakObject&& other)
        : fControl(other.fControl), fTypedObj(other.fTypedObj)
    {
        other.impl_zero();
    }
    template<class _Y>
    OSWeakObject(OSWeakObject<_Y>&& other)
        :fControl(other.fControl), fTypedObj(other.fTypedObj)
    {
        other.impl_zero();
    }
    OSWeakObject& operator=(OSWeakObject&& other)
    {
        if (this != &other)
        {
            impl_saferelease();
            fControl = other.fControl;
            fTypedObj = other.fTypedObj;
            other.impl_zero();
        }
        return *this;
    }

    void reset()           { *this = OSWeakObject(); };
    long use_count() const
    {
        return static_cast<bool>(fControl)? fControl->count_shared() : 0;
    }
    bool expired() const
    {
        return !static_cast<bool>(fControl) || fControl->count_shared() == 0;
    }
    OSSharedObject<_T> lock() const
    {
        OSSharedObject<_T> ret;
        ret.fControl = static_cast<bool>(fControl)? fControl->lock() : nullptr;
        if (static_cast<bool>(ret.fControl))
            ret.fTypedObj = fTypedObj;
        return ret;
    }

    void swap(OSWeakObject<_T>& other)
    {
        _VIOG::swap(fControl,  other.fControl);
        _VIOG::swap(fTypedObj, other.fTypedObj);
    }
};

// OSWeakObject non-member functions

// Free-function swap; forwards to OSWeakObject::swap().
template<class _T>
void swap(OSWeakObject<_T>& lhs, OSWeakObject<_T>& rhs)
{
    lhs.swap(rhs);
}

_IOG_END_NAMESPACE

#endif // !IOG_TL_OSMEMORY
